#!/usr/bin/env python3

import argparse
import os
import random
import pickle as pkl
import psutil
import time
import subprocess

# Root folder of the benchmarks, relative to the working directory.
# collect_benchmarks() joins it with a folder name to locate the benchmark files.
benchmark_root = "../benchmark/"

# -------------------------------------------------------------------------------
# Util functions for collecting the time cost of MaxFlash.

def clear_dict(dict_name):
    """Reset the directory /dict_name/ to an empty state.

    Whatever currently lives at that path (directory tree, plain file or
    symlink) is removed, then a fresh empty directory is created.  The work is
    done with the standard library rather than through shell commands, so
    paths containing spaces or shell metacharacters are handled correctly.

    Args:
        dict_name: Path of the directory to (re)create.

    Raises:
        OSError: If the old entry cannot be removed or the directory cannot
            be created.
    """
    import shutil  # Local import: shutil is not imported at file level.

    if os.path.isdir(dict_name) and not os.path.islink(dict_name):
        shutil.rmtree(dict_name)
    elif os.path.lexists(dict_name):
        # A plain file or a symlink: only the entry itself is removed.
        os.remove(dict_name)
    os.makedirs(dict_name, exist_ok=True)

# Keywords deleted from a benchmark file name to detect duplicate benchmarks.
# Every occurrence is removed, including the separators "-" and "_".
# Example: phone_8-long-repeat.sl -> phone8.sl
key_list = ["long", "short", "repeat", "-", "_", "small"]


def get_group_name(file_name):
    """Return /file_name/ with every keyword of key_list deleted.

    Keywords are processed one after another in list order; benchmarks that
    map to the same string are treated as one group.
    """
    group_name = file_name
    for keyword in key_list:
        # Splitting on the keyword and gluing the pieces back deletes it.
        group_name = "".join(group_name.split(keyword))
    return group_name

def train_model(file_list, depth, domain, is_clear='N'):
    """Train an n-gram model with n=/depth/ on the benchmarks in /file_list/.

    The benchmark files are copied into a freshly cleared `training_set`
    folder, the training script `../src/train/python/main.py` is run on it
    with `models` as its output argument, and `training_set` is cleared again.

    External commands are started with explicit argument lists (no shell), so
    benchmark paths containing spaces or shell metacharacters are passed
    through unchanged.

    Args:
        file_list: Paths of the benchmark files to train on.
        depth: The n of the n-gram model; forwarded to the training script.
        domain: Domain name forwarded to the training script.
        is_clear: Flag forwarded verbatim to the training script ('Y' or 'N'
            in this file); its exact meaning is defined by that script.
    """
    clear_dict("training_set")
    clear_dict("models")
    for file in file_list:
        # A failing copy is reported by `cp` itself and does not abort training.
        subprocess.call(["cp", file, "training_set"])
    status = subprocess.call([
        "python3", "../src/train/python/main.py",
        "training_set", str(depth), "models", is_clear, domain,
    ])
    if status != 0:
        # Previously a failing training run was ignored without any message.
        print("Warning: training script exited with status " + str(status))
    clear_dict("training_set")

def collect_benchmarks(name: str, suffix=".sl"):
    """Collect all benchmarks whose file name ends with /suffix/ from folder /name/.

    The folder is looked up below benchmark_root.  Benchmarks are grouped by
    get_group_name() so that variants of one benchmark stay together.

    Args:
        name: Sub-folder of benchmark_root to scan.
        suffix: Required file-name ending; the empty string accepts every file.

    Returns:
        A list of groups.  Each group is a list of dicts with the keys
        "name" (file name) and "path" (file name joined with the folder).
    """
    group_map = {}
    benchmark_folder = os.path.join(benchmark_root, name)
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # grouping reproducible across runs and machines.
    for benchmark_name in sorted(os.listdir(benchmark_folder)):
        # A real suffix test: a substring test would also accept names such
        # as "foo.sl.bak".
        if not benchmark_name.endswith(suffix):
            continue
        group_name = get_group_name(benchmark_name)
        group_map.setdefault(group_name, []).append({
            "name": benchmark_name,
            "path": os.path.join(benchmark_folder, benchmark_name),
        })
    return list(group_map.values())

def get_result(file_name):
    """Parse an output file of MaxFlash.

    The file is expected to hold the synthesized program on its first line and
    the time cost on its second line.

    Args:
        file_name: Path of the output file.

    Returns:
        {"Program": <first line>, "Time": <float>} on success.  None if the
        file does not exist, is empty, has a first line of at most two
        characters, has a short first line containing "NULL", or if the time
        line is missing or not a number (e.g. the run was killed while it was
        still writing).
    """
    try:
        with open(file_name, "r") as inp:
            lines = inp.readlines()
    except FileNotFoundError:
        return None
    if not lines:
        return None
    program_line = lines[0]
    if len(program_line) <= 2 or (len(program_line) <= 6 and "NULL" in program_line):
        return None
    if len(lines) < 2:
        # The program was written but the time line never made it to disk.
        return None
    try:
        # float() ignores surrounding whitespace, so the line is parsed
        # correctly with or without a trailing newline.
        time_cost = float(lines[1])
    except ValueError:
        return None
    return {"Program": program_line.rstrip("\n"), "Time": time_cost}

# Progress counters used for the "Now x/y" messages: the number of benchmarks in
# the current run_all() call, and the number of them that have finished so far.
# Both are reset by run_all(); total_finished is incremented by finish() and get_memory().
total_benchmark = 0
total_finished = 0

def finish(pos, ti, res):
    """Record /ti/ as the result of the benchmark in pool slot /pos/.

    The result is stored in dict /res/ under the benchmark's name, the global
    finished-counter is advanced and a progress line is printed.  The slot
    itself is left untouched; the caller frees it.
    """
    global total_finished
    slot = thread_pool[pos]
    assert slot is not None
    benchmark = slot["name"]
    res[benchmark] = ti
    total_finished += 1
    progress = "Now {}/{}".format(total_finished, total_benchmark)
    print("Finish ", benchmark, ti, progress)

def deal_with(pos, res):
    """Collect the result of pool slot /pos/ if its process has terminated.

    Popen.poll() returns None while the process is still running and the exit
    status otherwise, so every non-None status means "finished".  Exit status 1
    used to be excluded here, which left such a slot occupied forever and made
    the polling loops of the callers spin without end.

    Args:
        pos: Index into thread_pool; the slot must be occupied.
        res: Dict that receives the parsed result (None if the process left
            no usable output file).
    """
    slot = thread_pool[pos]
    process = slot["thread"]
    if process.poll() is None:
        return  # Still running.
    if process.stdout is not None:
        # The pipe is never read; close it so file descriptors do not pile up.
        process.stdout.close()
    finish(pos, get_result(slot["oup"]), res)
    thread_pool[pos] = None

# A counter used by run_all() to assign unique names to the output files.
counter = 0

def run_benchmark(name, command, oup, res):
    """Invoke MaxFlash on benchmark /name/ as soon as a pool slot is free.

    Occupied slots are polled (and harvested into /res/ when finished) until a
    free one is found; the function only sleeps when the whole pool is busy,
    so a launch is not delayed when a slot is immediately available.

    Args:
        name: Benchmark name, used as the key in /res/.
        command: Shell command line that runs MaxFlash.
        oup: Path of the output file the command writes; parsed once the
            process has terminated.
        res: Dict that receives the results of harvested slots.

    Raises:
        ValueError: If the pool has no slots (thread_num < 1); waiting for a
            free slot would never end.
    """
    if thread_num < 1:
        raise ValueError("thread_num must be at least 1")
    target_pos = None
    while True:
        for pos in range(thread_num):
            if thread_pool[pos] is not None:
                deal_with(pos, res)
            if thread_pool[pos] is None:
                target_pos = pos
                break
        if target_pos is not None:
            break
        time.sleep(1)
    # The output of the process is never read, so it is discarded instead of
    # being piped: a child that fills an unread pipe would block.
    process = subprocess.Popen(command, stdout=subprocess.DEVNULL, shell=True)
    thread_pool[target_pos] = {"name": name, "oup": oup, "thread": process}

def collect_all_result(res):
    """Wait until every pool slot is empty, storing all results in /res/.

    The pool is polled once per second; the loop ends after the first pass
    that finds no occupied slot.
    """
    while True:
        found_occupied = False
        for slot_index in range(thread_num):
            if thread_pool[slot_index] is None:
                continue
            deal_with(slot_index, res)
            found_occupied = True
        if not found_occupied:
            return
        time.sleep(1)


# -------------------------------------------------------------------------------
# Util functions for collecting the memory cost of MaxFlash.
def get_process_info(target_id):
    """Return the psutil process object whose pid is /target_id/.

    All running processes are scanned; None is returned when no process has
    that pid.
    """
    matches = (proc for proc in psutil.process_iter() if proc.pid == target_id)
    return next(matches, None)

def get_pid(pid, command, key, forbid):
    """Find the pid of the process started by the current invocation.

    One scan over all running processes is made.  A process qualifies when its
    name
      (1) is neither "sh" nor "timeout" (the wrappers of the invocation),
      (2) contains every keyword in /key/; if /key/ is empty the name must
          instead occur as a substring of /command/,
      (3) contains no keyword in /forbid/.
    Among the qualifying processes, the one whose pid is closest to /pid/ wins.

    The former 1000-iteration retry loop reused one exhausted process iterator
    and therefore never rescanned; it has been removed.  Callers that need to
    wait for the process to appear must call this function again.

    Args:
        pid: Pid of the shell that launched the invocation.
        command: The command line that was launched.
        key: Keywords the process name must contain.
        forbid: Keywords the process name must not contain.

    Returns:
        The pid of the best match, or None if no process qualifies.
    """
    result = None
    for process in psutil.process_iter():
        try:
            name = process.name()
        except (psutil.Error, OSError):
            # The process vanished or may not be inspected; skip it.
            continue
        if name in ("sh", "timeout"):
            continue
        if len(key) == 0 and name not in command:
            continue
        if not all(k in name for k in key):
            continue
        if any(k in name for k in forbid):
            continue
        if result is None or abs(process.pid - pid) < abs(result - pid):
            print("update ", name)
            result = process.pid
    return result

def get_memory_result(target_id, p, command, key, forbid):
    """Sample the peak resident memory of the process with pid=/target_id/.

    The process is polled in a tight loop and the largest `rss` value seen is
    kept.  Sampling stops when the process can no longer be found, when it
    cannot be inspected any more, or when its name differs from the name seen
    at the first sample.

    Args:
        target_id: Pid to observe, or None if it is not known yet.  While it
            is None, get_pid() is retried for as long as /p/ is running.
        p: The Popen object of the shell that launched the invocation.
        command, key, forbid: Forwarded to get_pid().

    Returns:
        The peak rss in bytes, or None if no sample could be taken.
    """
    name = None
    result = None
    while True:
        while target_id is None:
            # poll() is None only while the shell is alive.  Any exit status,
            # including 1, means the target can no longer show up; treating 1
            # as "still running" made this loop spin forever.
            if p.poll() is not None:
                return result
            target_id = get_pid(p.pid, command, key, forbid)
        process_info = get_process_info(target_id)
        if process_info is None:
            return result
        try:
            # name is a method; it has to be called to obtain the process name.
            current_name = process_info.name()
            memory = process_info.memory_info().rss
        except (psutil.Error, OSError):
            # The process terminated between the lookup and the query.
            return result
        if name is None:
            name = current_name
        elif name != current_name:
            return result
        result = memory if result is None else max(result, memory)

def get_memory(name, command, res, time_result, keyw=None, forbid=None):
    """Measure the peak memory cost of MaxFlash on benchmark /name/.

    The benchmark is only re-run when the time experiment recorded a result
    for it within the time limit; otherwise None is stored.

    Args:
        name: Benchmark name; key in both /time_result/ and /res/.
        command: Shell command line that runs MaxFlash.
        res: Dict that receives the peak memory in bytes (or None).
        time_result: Results of the time experiment, keyed by benchmark name.
        keyw: Keywords the process name must contain (default: none).
        forbid: Keywords the process name must not contain (default: none).
    """
    # None replaces the former mutable list defaults.
    keyw = [] if keyw is None else keyw
    forbid = [] if forbid is None else forbid
    # A benchmark missing from the time results is treated like a failed one.
    ti = time_result.get(name)
    if ti is not None and getTime(ti) <= time_limit:
        # The output is never read; discarding it keeps the child from
        # blocking on a full pipe while it is being measured.
        p = subprocess.Popen(command, stdout=subprocess.DEVNULL, shell=True)
        print(name, ti)
        res[name] = get_memory_result(get_pid(p.pid, command, keyw, forbid), p, command, keyw, forbid)
    else:
        res[name] = None
    global total_finished
    total_finished += 1
    print("Finish ", name, res[name], "Now " + str(total_finished) + "/" + str(total_benchmark))

# -------------------------------------------------------------------------------
def run_all(file_list, depth, res, domain, time_result=None):
    """Run MaxFlash on all benchmarks in /file_list/.

    If /time_result/ is None the time cost is collected (benchmarks run in
    parallel through the pool); otherwise the memory cost is collected, one
    benchmark at a time.

    Each run is wrapped in `ulimit -v <memory_limit>` and
    `timeout <time_limit>` and writes to temp/<counter>.res and
    temp/<counter>.log.  File arguments are shell-quoted, so benchmark paths
    containing spaces or shell metacharacters reach MaxFlash intact.

    Args:
        file_list: Paths of the benchmark files to run.
        depth: Selects the model file models/model@depth<depth>.json.
        res: Dict that receives one entry per benchmark, keyed by file name.
        domain: Value passed to MaxFlash as --type.
        time_result: Results of a previous time experiment, or None.
    """
    global counter
    global total_benchmark
    global total_finished
    import shlex  # Local import: shlex is not imported at file level.

    total_benchmark = len(file_list)
    total_finished = 0
    counter = 0
    clear_dict("temp")
    model_file = "models/model@depth" + str(depth) + ".json"
    for benchmark_file in file_list:
        counter += 1
        oup_file = "temp/" + str(counter) + ".res"
        log_file = "temp/" + str(counter) + ".log"
        # bin_file is left unquoted on purpose: it is user-supplied and may
        # carry extra arguments.  shlex.quote() leaves ordinary paths as is.
        command = " ".join([
            'ulimit -v ' + str(memory_limit) + ';' + "timeout " + str(time_limit) + " " + bin_file,
            "--model=" + shlex.quote(model_file),
            "--spec=" + shlex.quote(benchmark_file),
            "--oup=" + shlex.quote(oup_file),
            "--log=" + shlex.quote(log_file),
            "--type=" + shlex.quote(domain),
        ])
        print(command)
        benchmark_name = os.path.basename(benchmark_file)
        if time_result is None:
            run_benchmark(benchmark_name, command, oup_file, res)
        else:
            print("get memory")
            get_memory(benchmark_name, command, res, time_result)
    if time_result is None:
        # Wait for the runs that are still in the pool.
        collect_all_result(res)
    clear_dict("temp")

def getTime(data):
    """Get the time cost from a result.

    Args:
        data: Either a result dict with a "Time" entry (as produced by
            get_result), or a bare value such as a number or None.

    Returns:
        data["Time"] for any dict (including dict subclasses), otherwise
        /data/ itself.
    """
    if isinstance(data, dict):
        return data["Time"]
    return data

def run_fold(n, benchmark_list, depth, domain, time_cost=None):
    """Run /n/-fold cross-validation on /benchmark_list/ with a /depth/-gram model.

    The groups are shuffled and dealt round-robin into /n/ tracks.  For each
    fold, a model is trained on all other tracks and evaluated on the fold's
    own track.  If /time_cost/ is None the time cost is collected, otherwise
    the memory cost.

    Args:
        n: Number of folds.
        benchmark_list: List of benchmark groups as returned by
            collect_benchmarks(); it is not modified.
        depth: The n of the n-gram model.
        domain: Domain name passed on to training and evaluation.
        time_cost: Results of a previous time experiment, or None.

    Returns:
        A dict with one result per benchmark, keyed by benchmark file name.
    """
    # Shuffle a copy so the caller's list keeps its order.
    groups = list(benchmark_list)
    random.shuffle(groups)
    tracks = [[] for _ in range(n)]
    for (i, group) in enumerate(groups):
        tracks[i % n].extend(group)
    result = {}
    for i in range(n):
        print("Fold " + str(i) + "/" + str(n))
        test = []
        train = []
        for j in range(n):
            target = test if j == i else train
            target.extend(benchmark["path"] for benchmark in tracks[j])
        train_model(train, depth, domain, 'Y' if i == 0 else 'N')
        run_all(test, depth, result, domain, time_cost)
    return result

# -------------------------------------------------------------------------------
# Module for loading/saving results.
# The directory in which all existing results are cached.
cache_dir = "result_cache"

def load_cache(file):
    """Load a cached result from /file/.

    Args:
        file: Path of a pickle file written by save_cache().

    Returns:
        The unpickled object, or None if the file does not exist or is
        truncated/corrupt (e.g. a previous run was interrupted while saving).
        A None return makes the caller recompute the result.
    """
    if not os.path.exists(file):
        return None
    try:
        with open(file, "rb") as inp:
            # NOTE: unpickling can execute arbitrary code.  Only load cache
            # files that this script wrote itself.
            return pkl.load(inp)
    except (EOFError, pkl.UnpicklingError) as err:
        print("Ignoring unreadable cache file " + str(file) + ": " + repr(err))
        return None

def save_cache(file, result):
    """Pickle /result/ into /file/ without ever leaving a partial cache.

    The data is written to a temporary file next to /file/ and then moved into
    place with os.replace().  If pickling fails or the run is interrupted, a
    previously saved cache stays intact and the temporary file is removed.

    Args:
        file: Destination path of the cache file.
        result: The object to pickle.
    """
    tmp_file = os.fspath(file) + ".tmp"
    try:
        with open(tmp_file, "wb") as oup:
            pkl.dump(result, oup)
        os.replace(tmp_file, file)
    except BaseException:
        # Do not leave a half-written temporary file behind.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise

# -------------------------------------------------------------------------------
def run_experiment(domain, depth, fold_num):
    """Run the experiment on /domain/ with a /depth/-gram model.

    /fold_num/-fold cross-validation is used; 0 selects one fold per benchmark
    group.  Time and memory results are taken from the cache directory when
    present, otherwise they are computed and cached.  A summary is written to
    <domain><depth>-result.csv with the columns name, time and memory (the
    stored memory value divided by 1024 twice).
    """
    suffix = ".sl" if domain == "string" else ""
    benchmark_list = collect_benchmarks(domain, suffix)
    tag = domain + str(depth)
    time_file = os.path.join(cache_dir, tag + "-time.pkl")
    memory_file = os.path.join(cache_dir, tag + "-memory.pkl")
    output_file = tag + "-result.csv"

    time_result = load_cache(time_file)
    if fold_num == 0:
        fold_num = len(benchmark_list)
    if time_result is None:
        time_result = run_fold(fold_num, benchmark_list, depth, domain)
        save_cache(time_file, time_result)

    memory_result = load_cache(memory_file)
    if memory_result is None:
        memory_result = run_fold(fold_num, benchmark_list, depth, domain, time_result)
        save_cache(memory_file, memory_result)

    with open(output_file, "w") as oup:
        oup.write("Name, Time(s), Memory(MB)\n")
        for name, raw_time in time_result.items():
            seconds = getTime(raw_time)
            if seconds is None or seconds >= time_limit:
                # No result within the limit: memory was not measured either.
                time_cell, memory_cell = "timeout", "skipped"
            else:
                time_cell = seconds
                peak = memory_result[name]
                memory_cell = "skipped" if peak is None else peak / 1024 / 1024
            oup.write(", ".join([name, str(time_cell), str(memory_cell)]) + "\n")

def parse_args():
    """Parse the command-line arguments of the script.

    Options (short flag, long flag, type, default):
      -tl --timeLimit    float  0.5   time limit handed to `timeout`
      -ml --memoryLimit  int    8     scaled by 1024*1024 by the main block
                                      and handed to `ulimit -v`
      -tn --threadNum    int    8     number of pool slots
      -fn --foldNum      int    0     0 represents leave-one-out cross-validation
      -d  --depth        int    1     n of the n-gram model
      -t  --type         str    all   one of all, string, matrix
      -e  --exe          str    ../build/run  executable to benchmark
    """
    option_table = [
        (('-tl', '--timeLimit'), {'type': float, 'default': 0.5}),
        (('-ml', '--memoryLimit'), {'type': int, 'default': 8}),
        (('-tn', '--threadNum'), {'type': int, 'default': 8}),
        (('-fn', '--foldNum'), {'type': int, 'default': 0}),
        (('-d', '--depth'), {'type': int, 'default': 1}),
        (('-t', '--type'), {'type': str, 'choices': ['all', 'string', 'matrix'], 'default': 'all'}),
        (('-e', '--exe'), {'type': str, 'default': "../build/run"}),
    ]
    parser = argparse.ArgumentParser()
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()

# Script entry point: read the options into the module-level settings used by
# the functions above, then run the experiment for every selected domain.
if __name__ == "__main__":
    # Create the cache directory without going through a shell.
    os.makedirs(cache_dir, exist_ok=True)
    args = parse_args()
    if args.threadNum < 1:
        # With an empty pool, waiting for a free slot would never end.
        raise SystemExit("--threadNum must be at least 1")
    time_limit = args.timeLimit
    # Scaled value handed to `ulimit -v` in run_all().
    memory_limit = args.memoryLimit * 1024 * 1024
    thread_num = args.threadNum
    # One slot per concurrently running benchmark; None marks a free slot.
    thread_pool = [None for _ in range(thread_num)]
    fold_num = args.foldNum
    depth = args.depth
    domain_list = ["string", "matrix"] if args.type == "all" else [args.type]
    bin_file = args.exe
    for domain in domain_list:
        run_experiment(domain, depth, fold_num)